Enable non-dim-0 FSDP sharding of MoE experts when ep=1#2668

Open
aws-ritikadm wants to merge 1 commit into pytorch:main from aws-ritikadm:ep1-fsdp-sharding

Conversation

@aws-ritikadm

Summary

Previously, routed experts in MoE layers were only separately wrapped with fully_shard when ep_degree > 1. When ep_degree == 1, experts were sharded only as part of the outer TransformerBlock FSDP group, which meant the Shard(1) placement optimization (sharding on hidden_dim instead of num_experts) was never applied.

This PR extends the separate expert FSDP wrapping to also apply when ep_degree == 1. When the FSDP degree exceeds num_experts, experts are sharded on dim 1 (hidden_dim) to avoid padding inefficiency from dim-0 sharding — the same optimization that was already in place for ep > 1.
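The placement rule described above can be sketched as a small standalone function. This is an illustrative sketch only, not the PR's actual code: torchtitan applies the placement via `fully_shard`, and `expert_shard_dim` here is a hypothetical helper that just captures the decision.

```python
# Hypothetical sketch of the expert-sharding placement rule described above.
# The real change wires this decision into fully_shard; this helper only
# models which dim gets sharded.

def expert_shard_dim(fsdp_degree: int, num_experts: int) -> int:
    """Pick the dim to shard routed-expert weights on.

    Expert weights are laid out with num_experts on dim 0 and hidden_dim
    on dim 1. If the FSDP degree exceeds num_experts, dim-0 sharding would
    pad the expert dimension, so shard on dim 1 (hidden_dim) instead.
    """
    return 1 if fsdp_degree > num_experts else 0

# Mirrors the three validation cases in this PR:
assert expert_shard_dim(fsdp_degree=8, num_experts=4) == 1      # ep=1, new path
assert expert_shard_dim(fsdp_degree=8, num_experts=8) == 0      # no padding issue
assert expert_shard_dim(fsdp_degree=4 * 2, num_experts=4) == 1  # pre-existing ep=2 path
```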

Validation

python -m pytest tests/unit_tests/test_fsdp_moe_sharding.py -v

collecting ... collected 3 items
tests/unit_tests/test_fsdp_moe_sharding.py::TestApplyFsdpMoESharding::test_no_ep_fsdp_gt_num_experts_shards_dim1 PASSED [ 33%]
tests/unit_tests/test_fsdp_moe_sharding.py::TestApplyFsdpMoESharding::test_no_ep_fsdp_le_num_experts_shards_dim0 PASSED [ 66%]
tests/unit_tests/test_fsdp_moe_sharding.py::TestApplyFsdpMoESharding::test_with_ep_fsdp_gt_num_experts_shards_dim1 PASSED [100%]

The three tests:

| Test | Setup | What it checks |
| --- | --- | --- |
| test_no_ep_fsdp_gt_num_experts_shards_dim1 | ep=1, 4 experts, 8 FSDP ranks | 8 > 4 → experts sharded on dim 1. This is the new code path that this change enables. |
| test_no_ep_fsdp_le_num_experts_shards_dim0 | ep=1, 8 experts, 8 FSDP ranks | 8 == 8 → no padding issue, experts sharded on dim 0 (default). Also exercises the new else branch, but without triggering Shard(1). |
| test_with_ep_fsdp_gt_num_experts_shards_dim1 | ep=2, 4 experts, edp mesh [efsdp=4, ep=2] | 4*2=8 > 4 → experts sharded on dim 1. This is the pre-existing EP path, included for regression coverage. |
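The padding inefficiency that motivates Shard(1) is easy to quantify. The arithmetic below is illustrative (not taken from the PR): with dim-0 sharding, FSDP pads dim 0 up to a multiple of the shard degree, so a shard degree larger than num_experts wastes a fixed fraction of the allocated shard space.

```python
import math

def dim0_padding_fraction(num_experts: int, fsdp_degree: int) -> float:
    """Fraction of the padded dim-0 shard space that is padding.

    Dim-0 sharding gives each rank ceil(num_experts / fsdp_degree) expert
    slots, so the padded total is that per-rank count times the degree.
    """
    per_rank = math.ceil(num_experts / fsdp_degree)
    padded_total = per_rank * fsdp_degree
    return (padded_total - num_experts) / padded_total

# 4 experts over 8 ranks: every rank gets one expert-sized slot, so 8 slots
# back only 4 real experts -> 50% of the shard space is padding.
assert dim0_padding_fraction(4, 8) == 0.5
# 8 experts over 8 ranks divide evenly -> no padding, dim 0 stays fine.
assert dim0_padding_fraction(8, 8) == 0.0
```

Sharding on dim 1 (hidden_dim) avoids this because hidden_dim is typically much larger than the FSDP degree, so the padded remainder is negligible.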

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 23, 2026
@aws-ritikadm aws-ritikadm marked this pull request as ready for review March 23, 2026 07:45